In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch

SIIM-ACR Pneumothorax Segmentation.

Data loader.

Training set.

In [2]:
from pydicom import dcmread
from torch.utils.data import Dataset
from lib.mask_functions import rle2mask

class TrainingSet(Dataset):
    """Training samples for SIIM-ACR pneumothorax segmentation.

    Reads image ids and RLE-encoded masks from ./dataset/train.csv; each
    item is an (image, mask) pair of float arrays shaped (1, H, W),
    scaled to [0, 1].
    """

    def __init__(self):
        # One row per image: 'ImageId' and 'EncodedPixels' (RLE or '-1').
        self.df = pd.read_csv('./dataset/train.csv')

    def __getitem__(self, idx):
        imageId, rle = self.df['ImageId'][idx], self.df['EncodedPixels'][idx]

        dcm = dcmread('./dataset/dicom-images-train/{}.dcm'.format(imageId))

        # pixel_array is laid out (Rows, Columns) == (height, width).
        image, height, width = dcm.pixel_array, dcm.Rows, dcm.Columns

        # '-1' (possibly whitespace-padded in the CSV) marks "no
        # pneumothorax": use an all-zero mask.  The empty mask is built
        # as (height, width) to match the image layout; the original
        # reshape(width, height) was only correct for square scans.
        if str(rle).strip() != '-1':
            mask = rle2mask(rle, width, height).T
        else:
            mask = np.zeros((height, width))

        # Normalize to [0, 1] and add a leading channel axis.
        return np.array([(image / 255)]), np.array([(mask / 255)])

    def __len__(self):
        # One sample per CSV row.
        return len(self.df)

def training_samples(n = 4, m = 6):
    """Plot the first n*m training scans that have a positive mask,
    on an n-by-m grid, with the mask overlaid in red."""

    fig, axes = plt.subplots(n, m, figsize = (m * 5, n * 5))

    shown = 0

    for sample_id, (image, mask) in enumerate(TrainingSet()):

        # Only scans that actually contain a pneumothorax mask.
        if not (mask.max() > 0):
            continue

        ax = axes[shown // m][shown % m]
        ax.set_title('Train {}'.format(sample_id))
        ax.imshow(image[0], cmap = plt.cm.bone)
        ax.imshow(mask[0], alpha = 0.3, cmap = 'Reds')

        shown += 1
        if shown == n * m:
            break

training_samples()

Test set.

In [3]:
class TestSet(Dataset):
    """Test-set scans for SIIM-ACR pneumothorax segmentation (no masks)."""

    def __init__(self):
        # One row per test image; only the 'ImageId' column is used.
        self.df = pd.read_csv('./dataset/test.csv')

    def __getitem__(self, idx):
        imageId = self.df['ImageId'][idx]

        dcm = dcmread('./dataset/dicom-images-test/{}.dcm'.format(imageId))

        # Scale to [0, 1] and add a leading channel axis: (1, H, W).
        pixels = dcm.pixel_array
        return np.array([pixels / 255])

    def __len__(self):
        # One sample per CSV row.
        return len(self.df)
    
def test_samples(n = 4, m = 6):
    """Plot the first n*m test scans on an n-by-m grid."""

    fig, axes = plt.subplots(n, m, figsize = (m * 5, n * 5))

    total = n * m

    for idx, image in enumerate(TestSet()):
        ax = axes[idx // m][idx % m]
        ax.set_title('Test {}'.format(idx))
        ax.imshow(image[0], cmap = plt.cm.bone)

        if idx + 1 == total:
            break

test_samples()

Augmentation.

TODO.

UNet model.

In [4]:
from torch.nn import Module, Sequential, Conv2d, ReLU, MaxPool2d, Upsample

class UNet(Module):
    """UNet for single-channel input and 2-channel output.

    Four encoder stages (64 -> 512 channels) with 2x max-pool downsampling,
    a 1024-channel bottleneck, and four decoder stages with bilinear
    upsampling and skip connections, finished by a 1x1 convolution.
    """

    def double_conv(self, in_channels, out_channels):
        # Two 3x3 convolutions (padding keeps the spatial size), each ReLU'd.
        return Sequential(
            Conv2d(in_channels, out_channels, 3, padding = 1),
            ReLU(inplace = True),
            Conv2d(out_channels, out_channels, 3, padding = 1),
            ReLU(inplace = True))

    def upconv(self, in_channels, out_channels):
        # Bilinear 2x upsampling, then a 3x3 convolution to reduce channels.
        return Sequential(
            Upsample(scale_factor = 2, mode = 'bilinear', align_corners = True),
            Conv2d(in_channels, out_channels, 3, padding = 1))

    def __init__(self):
        super(UNet, self).__init__()

        # Encoder stages.
        self.dconv1 = self.double_conv(1, 64)
        self.dconv2 = self.double_conv(64, 128)
        self.dconv3 = self.double_conv(128, 256)
        self.dconv4 = self.double_conv(256, 512)

        # Bottleneck.
        self.dconv5 = self.double_conv(512, 1024)

        # Decoder stages (input channels doubled by the skip concatenation).
        self.dconv6 = self.double_conv(1024, 512)
        self.dconv7 = self.double_conv(512, 256)
        self.dconv8 = self.double_conv(256, 128)
        self.dconv9 = self.double_conv(128, 64)

        self.up1 = self.upconv(1024, 512)
        self.up2 = self.upconv(512, 256)
        self.up3 = self.upconv(256, 128)
        self.up4 = self.upconv(128, 64)

        # Shared 2x max-pool used between all encoder stages.
        self.down = MaxPool2d(2)

        # 1x1 projection to the 2 output channels.
        self.fc = Conv2d(64, 2, 1)

    def forward(self, x):
        # Encode, remembering each pre-pool activation for the skip links.
        skips = []
        for encode in (self.dconv1, self.dconv2, self.dconv3, self.dconv4):
            x = encode(x)
            skips.append(x)
            x = self.down(x)

        x = self.dconv5(x)

        # Decode: upsample, concatenate the matching skip, double-conv.
        upsamplers = (self.up1, self.up2, self.up3, self.up4)
        decoders = (self.dconv6, self.dconv7, self.dconv8, self.dconv9)
        for up, decode, skip in zip(upsamplers, decoders, reversed(skips)):
            x = decode(torch.cat([skip, up(x)], dim = 1))

        return self.fc(x)

Dice loss.

$loss = 1 - \frac{2 \lvert X \cap Y \rvert}{\lvert X \rvert + \lvert Y \rvert}$

In [5]:
from torch.nn.functional import binary_cross_entropy_with_logits, sigmoid

def dice_loss(outputs, masks, smooth = 1):
    """Soft dice loss averaged over batch and channels.

    Parameters:
        outputs -- raw model logits, shape (B, C, H, W)
        masks   -- target masks in [0, 1], broadcastable to outputs
        smooth  -- additive smoothing so empty masks give zero loss

    Returns a scalar tensor: mean of 1 - dice coefficient.
    """
    # The original computed dice directly on raw logits (which can be
    # negative, making the "loss" unbounded); `sigmoid` was imported but
    # never applied.  Map logits to probabilities first.
    probs = torch.sigmoid(outputs).contiguous()
    masks = masks.contiguous()

    # Per-sample, per-channel sums over the spatial dimensions.
    intersection = (probs * masks).sum(dim = (2, 3))
    pred_sum = probs.sum(dim = (2, 3))
    true_sum = masks.sum(dim = (2, 3))

    loss = 1 - ((2 * intersection + smooth) / (pred_sum + true_sum + smooth))

    return loss.mean()

Training.

In [ ]:
from torch.utils.data import DataLoader
from torch.optim import Adam

def train(model, num_epoch = 10, batch_size = 1):
    """Train `model` on the TrainingSet with Adam and dice loss.

    Parameters:
        model      -- a torch Module mapping (B, 1, H, W) images to logits
        num_epoch  -- number of passes over the training set
        batch_size -- samples per optimizer step

    Returns the trained model (moved to the selected device).
    """
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device = torch.device('cpu')

    model = model.to(device)
    model.train()  # ensure training-mode behavior for any BN/dropout layers

    # Build the loader once.  The original constructed a brand-new
    # DataLoader(TrainingSet()) inside the epoch loop, re-reading the
    # CSV on every epoch for no benefit.
    data_loader = DataLoader(TrainingSet(), batch_size = batch_size)

    optimizer = Adam(model.parameters())

    with tqdm(total = num_epoch * len(data_loader)) as pbar:
        for epoch in range(num_epoch):
            for images, masks in data_loader:

                images, masks = images.to(device), masks.to(device)

                # Cast both to float32 (the arrays arrive as float64).
                outputs = model(images.float())
                loss = dice_loss(outputs, masks.float())

                pbar.set_postfix(loss = loss.cpu().item())

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                pbar.update()

    return model

model = train(UNet(), num_epoch = 1, batch_size = 1)